""" LCD Implementation """
from typing import Any, Tuple, NamedTuple, List, Dict, Union, Type, Optional, Callable

import gym
import numpy as np

from sb3_jax.common.preprocessing import get_flattened_obs_dim, get_act_dim

from diffgro.environments.collect_dataset import get_skill_embed
from diffgro.lcd.planner import LCDPlanner
from diffgro.lcd.functions import guide_fn_dict
from diffgro.utils import llm
from diffgro.utils import *


class LCD:
    """Agent wrapper around an ``LCDPlanner`` policy with optional guidance.

    Each call to :meth:`predict` first queries the planner's high-level
    policy (``_predict_hact``) for a plan, optionally passing a guidance loss
    function and scale, and then queries the low-level policy
    (``_predict_lact``) for the action to execute.

    Supported guide modes are listed in ``guide_methods``. ``_setup_guide`` is
    not invoked by this class; callers are expected to set ``context_info``
    (for ``'manual'``) and call it before predicting with guidance.
    """

    guide_methods = ['blank', 'test', 'manual', 'llm']

    def __init__(
        self,
        env: "gym.Env",
        planner: "LCDPlanner",
        guide: Optional[str] = None,      # guide function
        guide_pt: Optional[str] = None,   # prompt or path for llm guidance
        delta: float = 1.0,               # scale for guidance
        verbose: bool = False,
    ):
        """Store the environment, the planner's policy and guidance settings.

        Args:
            env: Environment; must expose ``env_name`` and ``domain_name``.
            planner: Planner whose ``policy`` attribute is used for inference.
            guide: One of ``guide_methods`` or ``None`` for no guidance.
            guide_pt: Guide prompt / key; its use depends on ``guide``.
            delta: Guidance scale forwarded to the high-level policy.
            verbose: Verbosity flag (stored, not used in this class).
        """
        self.env = env
        self.planner = planner.policy
        # guidance
        if guide is not None:
            assert guide in LCD.guide_methods, f"Guide method {guide} should be in {LCD.guide_methods}"
        self.guide = guide
        self.guide_fn = guide_fn_dict[guide] if guide is not None else None
        self.guide_pt = guide_pt # context
        self.context_info = None
        self.delta = delta
        # whether the most recent high-level prediction used guidance
        self.guided = False
        # misc
        self.verbose = verbose

        self._setup()

    def _setup(self) -> None:
        """Cache space dimensions, the planning horizon and task embeddings."""
        self.obs_dim = get_flattened_obs_dim(self.env.observation_space)
        self.act_dim = get_act_dim(self.env.action_space)
        self.horizon = self.planner.horizon

        # task embedding, reshaped to a single-row batch
        self.task = get_skill_embed(None, self.env.env_name).reshape(1, -1)
        if self.env.domain_name == 'metaworld_complex':
            # one embedding per sub-task, indexed later by env.success_count
            self.skill = [get_skill_embed(None, task).reshape(1, -1) for task in self.env.full_task_list]

    def _setup_guide(self) -> None:
        """Build the per-task guidance loss functions for the chosen mode."""
        # guidance settings
        self.n_guide_steps = 1
        if self.guide == 'test':
            # same loss function, looked up by guide_pt, for every task
            self.loss_fn = [self.guide_fn[self.guide_pt] for _ in range(self.env.task_num)]
        if self.guide == 'blank': # no guidance only for evaluating contexts
            self.n_guide_steps = 0
        if self.guide == 'manual':
            self.loss_fn, self.guide_pt, self.loss_pt = [], [], []
            for context in self.context_info:
                # context entries are read positionally: [0] is recorded as
                # the prompt, [2]/[3] are type/target, [4] is the scale
                context_dict = {"context_type": context[2], "context_target": context[3]}
                self.loss_pt.append(context[0])
                self.guide_pt.append(context[0])
                loss_fn, _ = self.guide_fn(**context_dict)
                self.loss_fn.append(loss_fn)
                # NOTE(review): overwritten on every iteration, so only the
                # last context's scale is kept -- confirm this is intended
                self.delta = context[4]
        if self.guide  == 'llm':
            # NOTE(review): no loss functions are built here for 'llm';
            # presumably they are provided elsewhere -- confirm
            pass
        print_b(f"[lcd] guidance function is '{self.guide}' and scale is '{self.delta}'")
        print_b(f"[lcd] the guide prompt is '{self.guide_pt}'")

    def reset(self) -> None:
        """Reset the step counters and the observation buffer."""
        self.h, self.t = 0, 0
        self.obs_stack = np.zeros((1, self.horizon, self.obs_dim))

    def predict(self, obs: np.ndarray, deterministic: bool = True) -> Tuple[np.ndarray, None, Dict[str, bool]]:
        """Return the action for ``obs``.

        Args:
            obs: Single (unbatched) observation; a leading batch axis is added.
            deterministic: Accepted for interface compatibility; the planner
                is always queried with ``deterministic=True``.

        Returns:
            ``(action, None, info)`` where ``info["guided"]`` tells whether
            the high-level plan was produced with a guidance loss.
        """
        # add batch dimension
        obs = obs.reshape((-1,) + obs.shape)

        task, skill = self.task, None
        if self.env.domain_name == 'metaworld_complex':
            skill = self.skill[self.env.success_count]

        # 1. inference high-level policy
        # 'blank' means "evaluate without guidance": compare the guide *mode*;
        # guide_fn is the looked-up entry and never equals the string 'blank'.
        if self.guide_fn is None or self.guide == 'blank':
            plan = self.predict_hact_without_guide(obs, task, skill)
        else:
            plan = self.predict_hact_with_guide(obs, task, skill)

        # 2. inference low-level policy
        # conditioned on the plan entry at index 1 along axis 1
        # (presumably the next-step target -- confirm against the planner)
        act = self.planner._predict_lact(obs, plan[:,1,:], skill)
        act = act[0]

        self.t += 1

        act = np.array(act.copy())
        return act, None, {"guided": self.guided}

    def predict_hact_without_guide(self, obs, task, skill):
        """Query the high-level policy with no guidance and return its plan."""
        self.guided = False
        plan, info = self.planner._predict_hact(obs, task, skill, delta=None, guide_fn=None, deterministic=True)
        return plan

    def predict_hact_with_guide(self, obs, task, skill):
        """Query the high-level policy with the active guidance loss.

        In the ``'metaworld_complex'`` domain the loss is selected by
        ``env.success_count``; when no loss is configured for that index the
        call falls back to an unguided prediction (``guide_fn=None``).
        """
        self.guided = True
        loss_fn = self.loss_fn[0]
        if self.env.domain_name == 'metaworld_complex':
            try:
                loss_fn = self.loss_fn[self.env.success_count]
            except IndexError:
                # more sub-tasks completed than loss functions configured
                loss_fn = None
                self.guided = False
        plan, info = self.planner._predict_hact(obs, task, skill, delta=self.delta, guide_fn=loss_fn, deterministic=True)
        return plan
